#include "opencv2/core/mat.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"

#include "dnn_node/dnn_node.h"
#include "dnn_node/util/image_proc.h"
#include "sensor_msgs/msg/image.hpp"

using namespace hobot::dnn_node;

// ROS 2 node that subscribes to NV12 images on "/image_raw", runs the
// multitask_body_kps_960x544 model on each frame through the DnnNode base
// class, and renders the detected body boxes (see PostProcess()).
class BodyDetNode : public DnnNode {
 public:
  BodyDetNode(
  const std::string & node_name = "body_det",
  const rclcpp::NodeOptions & options = rclcpp::NodeOptions()) :
    DnnNode(node_name, options) {
    // Initialize the model, then query its input width (model_input_width_)
    // and height (model_input_height_), which are needed to preprocess
    // incoming images.
    if (Init() != 0 ||
      GetModelInputSize(0, model_input_width_, model_input_height_) < 0) {
      RCLCPP_ERROR(rclcpp::get_logger("example"), "Node init fail!");
      // Without a usable model there is nothing to feed images to: do not
      // subscribe, otherwise every frame would be preprocessed with the -1
      // placeholder input size.
      return;
    }

    // Subscribe to sensor_msgs::msg::Image messages on "/image_raw" (queue
    // depth 10); every received image is handed to FeedImg() for inference.
    ros_img_subscription_ = this->create_subscription<sensor_msgs::msg::Image>(
      "/image_raw", 10,
      [this](sensor_msgs::msg::Image::ConstSharedPtr msg) { FeedImg(msg); });
  }
  ~BodyDetNode() override = default;

  void FeedImg(const sensor_msgs::msg::Image::ConstSharedPtr msg);

 protected:
  // Called by the DnnNode base class during Init() to configure the model.
  // Returns -1 if the parameter object is missing, 0 otherwise.
  int SetNodePara() override {
    if (!dnn_node_para_ptr_) return -1;
    // Model file and model name used for inference.
    dnn_node_para_ptr_->model_file = "config/multitask_body_kps_960x544.hbm";
    dnn_node_para_ptr_->model_name = "multitask_body_kps_960x544";

    // Parse the body-box output (index box_output_index_) with the
    // detection-box parser predefined by hobot_dnn,
    // FaceHandDetectionOutputParser.
    std::shared_ptr<OutputParser> box_out_parser =
        std::make_shared<FaceHandDetectionOutputParser>();
    dnn_node_para_ptr_->output_parsers_.emplace_back(
      std::make_pair(box_output_index_, box_out_parser)
    );
    return 0;
  }

  int PostProcess(const std::shared_ptr<DnnNodeOutput> &node_output)
    override;

 private:
  // Model input size; stays -1 if it could not be queried from the model.
  int model_input_width_ = -1;
  int model_input_height_ = -1;
  // Index of the model output that holds the body boxes.
  const int32_t box_output_index_ = 1;
  // Most recent image submitted for inference; rendered in PostProcess().
  sensor_msgs::msg::Image::ConstSharedPtr img_msg_;
  rclcpp::Subscription<sensor_msgs::msg::Image>::ConstSharedPtr
    ros_img_subscription_ = nullptr;
};

// Log the model output, render the detected body boxes onto the source image
// and save the rendered image as a JPEG file named after the image stamp.
//
// node_output: inference result; outputs[box_output_index_] is expected to
//   hold the Filter2DResult produced by the parser registered in
//   SetNodePara().
// Returns 0 on success, -1 when the outputs or the cached source image are
// unusable, or when the rendered image cannot be written.
int BodyDetNode::PostProcess(
  const std::shared_ptr<DnnNodeOutput> &node_output) {
  // outputs[box_output_index_] is read below, so the vector must contain
  // strictly more than box_output_index_ elements (this also rejects an
  // empty vector).
  if (!node_output ||
    static_cast<int32_t>(node_output->outputs.size()) <= box_output_index_) {
    RCLCPP_ERROR(rclcpp::get_logger("example"), "Invalid outputs");
    return -1;
  }
  // img_msg_ is the frame last submitted by FeedImg(); guard against being
  // called before any frame was fed.
  if (!img_msg_) {
    RCLCPP_ERROR(rclcpp::get_logger("example"), "No image to render");
    return -1;
  }

  // An NV12 frame is width*height luma bytes followed by width*height/2
  // interleaved chroma bytes and needs even dimensions. Validate before
  // wrapping the buffer so cv::Mat never reads past the message data and
  // cv::cvtColor does not reject the size.
  const std::size_t width = img_msg_->width;
  const std::size_t height = img_msg_->height;
  if (width == 0 || height == 0 || width % 2 != 0 || height % 2 != 0 ||
    img_msg_->data.size() < width * height * 3 / 2) {
    RCLCPP_ERROR(rclcpp::get_logger("example"), "Invalid nv12 img size!");
    return -1;
  }

  auto *filter2d_result =
    dynamic_cast<Filter2DResult *>(node_output->outputs[box_output_index_].get());
  if (!filter2d_result) return -1;

  std::stringstream ss;
  ss << "img encoding: " << img_msg_->encoding
  << ", stamp: " << img_msg_->header.stamp.sec << "," << img_msg_->header.stamp.nanosec
  << "\nout box size: " << filter2d_result->boxes.size() << "\n";

  // cv::Mat only accepts a non-const data pointer; the wrapped buffer is
  // used purely as the cvtColor source and is never written.
  cv::Mat nv12(static_cast<int>(height * 3 / 2), static_cast<int>(width), CV_8UC1,
    const_cast<uint8_t *>(img_msg_->data.data()));
  cv::Mat bgr;
  cv::cvtColor(nv12, bgr, cv::COLOR_YUV2BGR_NV12);

  for (const auto &rect : filter2d_result->boxes) {
    ss << "rect: " << rect.left << " " << rect.top
        << " " << rect.right << " " << rect.bottom << "\n";
    // Draw the box on the image.
    cv::rectangle(bgr,
        cv::Point(rect.left, rect.top), cv::Point(rect.right, rect.bottom),
        cv::Scalar(255, 0, 0), 3);
  }

  std::string result_image = "render_" +
    std::to_string(img_msg_->header.stamp.sec) + "." +
    std::to_string(img_msg_->header.stamp.nanosec) + ".jpg";
  ss << "Render img to file: " << result_image;
  RCLCPP_INFO(rclcpp::get_logger("example"), "%s", ss.str().c_str());

  if (!cv::imwrite(result_image, bgr)) {
    RCLCPP_ERROR(rclcpp::get_logger("example"), "Failed to write %s",
      result_image.c_str());
    return -1;
  }
  return 0;
}

// Convert an incoming NV12 image into the model input type (DNNInput) and
// submit it to inference via the DnnNode base class Run(). Images that are
// not NV12, have an inconsistent buffer size, or arrive while the model
// input size is unknown are rejected with an error log.
void BodyDetNode::FeedImg(const sensor_msgs::msg::Image::ConstSharedPtr img_msg) {
  if (!img_msg) return;
  if ("nv12" != img_msg->encoding) {
    RCLCPP_ERROR(rclcpp::get_logger("example"), "Only support nv12 img encoding!");
    return;
  }
  // model_input_*_ keep their -1 placeholder when Init()/GetModelInputSize()
  // failed; preprocessing to that size is meaningless.
  if (model_input_width_ <= 0 || model_input_height_ <= 0) {
    RCLCPP_ERROR(rclcpp::get_logger("example"),
      "Model input size unavailable, skip inference!");
    return;
  }
  // An NV12 frame holds width*height luma bytes plus width*height/2 chroma
  // bytes; a shorter buffer would be read out of bounds downstream.
  const std::size_t expected_size =
    static_cast<std::size_t>(img_msg->width) * img_msg->height * 3 / 2;
  if (expected_size == 0 || img_msg->data.size() < expected_size) {
    RCLCPP_ERROR(rclcpp::get_logger("example"), "Invalid nv12 img size!");
    return;
  }

  // Build the model input data.
  auto pyramid = ImageProc::GetNV12PyramidFromNV12Img(
    reinterpret_cast<const char*>(img_msg->data.data()),
    img_msg->height, img_msg->width, model_input_height_, model_input_width_);
  if (!pyramid) {
    RCLCPP_ERROR(rclcpp::get_logger("example"), "Get nv12 pyramid fail!");
    return;
  }
  // Keep the frame so PostProcess() can render the results onto it.
  img_msg_ = img_msg;

  auto inputs = std::vector<std::shared_ptr<DNNInput>>{pyramid};
  // Start inference through the interface declared and implemented by the
  // DnnNode base class.
  if (Run(inputs) != 0) {
    RCLCPP_ERROR(rclcpp::get_logger("example"), "Run inference fail!");
  }
}

int main(int argc, char** argv) {
  rclcpp::init(argc, argv);
  rclcpp::spin(std::make_shared<BodyDetNode>());
  rclcpp::shutdown();
  return 0;
}
